from mc import McEnv

from SCM import SCM
from HRL import HRL
from utils import str2bool

import argparse
import os
from mpi4py import MPI
from mpi4py.MPI import COMM_WORLD as comm
import numpy as np

import torch
try:
    from torch.utils.tensorboard import SummaryWriter
except ModuleNotFoundError:
    from tensorboardX import SummaryWriter

def do_sample_parallel(env, hrl, do_variables, data_size, rank, size, sample_step_size):
    """Collect ``data_size`` transitions per do-variable, split across MPI ranks.

    The requested total is divided among the ``size`` ranks. The first
    ``data_size % size`` ranks take one extra sample so the shards add up to
    exactly ``data_size``. Each rank runs ``do_sample`` on its share, and the
    shards are gathered on rank 0.

    Returns:
        On rank 0: ``(x_datas, y_datas, z_datas)``, concatenated along the
        sample axis and viewed as ``(len(do_variables), data_size,
        env.variable_num)``. ``z_datas`` uses ``env.aux_info_num`` as its last
        dimension. On every other rank: that rank's local shard, as returned
        by ``do_sample``.
    """
    # Distribute the remainder so the gathered total is exactly data_size;
    # plain floor division would drop samples and break the final view.
    local_size = data_size // size + (1 if rank < data_size % size else 0)
    shard = do_sample(env, hrl, do_variables, local_size, sample_step_size)
    shards = comm.gather(shard, root=0)
    if rank != 0:
        return shard
    last_dims = (env.variable_num, env.variable_num, env.aux_info_num)
    x_datas, y_datas, z_datas = (
        torch.cat([s[k] for s in shards], dim=1).view(len(do_variables), data_size, last_dims[k])
        for k in range(3)
    )
    return x_datas, y_datas, z_datas
def do_sample(env, hrl, do_variables, expect_data_size, sample_step_size):
    """Collect ``expect_data_size`` transitions for each variable in ``do_variables``.

    Each do-variable has two buckets, ``do_type`` 0 and 1, which are filled
    to a balanced quota. A variable is eligible from the current state only
    when it is the action variable (index ``env.variable_num - 1``) or its
    entry in ``state`` equals 0. On every outer iteration, one eligible bucket
    with remaining quota is drawn uniformly at random.

    For ``do_type == 1`` on a non-action variable, ``hrl.evaluate`` is first
    run with that variable as the goal. The attempt is abandoned if the goal
    is not achieved. Then up to ``sample_step_size`` uniformly random actions
    are taken. A transition is recorded when the step is not ``done``, the
    action is > 3 and ``aux_info[0] > 0``. For goal-driven (type 1) samples of
    non-action variables, transitions that leave the state unchanged are
    dropped. The rollout stops early when the episode ends, the do-variable's
    value changes, or the variable already has ``expect_data_size`` samples.

    The last entry of each recorded ``x``/``y`` state is overwritten with the
    action taken.

    Returns:
        ``(x_datas, y_datas, z_datas)`` with shapes
        ``(len(do_variables), expect_data_size, env.variable_num)`` for
        ``x``/``y`` and ``(len(do_variables), expect_data_size,
        env.aux_info_num)`` for ``z``.
    """
    action_variable = env.variable_num - 1
    num_do = len(do_variables)
    # Quotas for do_type 0 and 1. Type 0 takes the extra sample when
    # expect_data_size is odd, so every variable ends with exactly
    # expect_data_size transitions; an even size splits evenly as before.
    quota = ((expect_data_size + 1) // 2, expect_data_size // 2)
    data = [[[], [], []] for _ in range(num_do)]
    data_size = [0] * (2 * num_do)

    def pending(bucket):
        # Bucket index is 2 * do_idx + do_type.
        return data_size[bucket] < quota[bucket % 2]

    # Buckets belonging to the action variable (hoisted out of the loop).
    if action_variable in do_variables:
        action_idx = do_variables.index(action_variable)
        action_buckets = {2 * action_idx, 2 * action_idx + 1}
    else:
        action_buckets = set()

    complex_state, state = env.reset()
    aux_info = env.info['aux_info']
    while any(pending(b) for b in range(2 * num_do)):
        operators = []
        for do_idx, do_variable in enumerate(do_variables):
            if do_variable == action_variable or state[do_variable] == 0:
                operators.extend(b for b in (2 * do_idx, 2 * do_idx + 1) if pending(b))
        # Nothing samplable from this state, or (with several do-variables)
        # only the action variable is: start a fresh episode.
        if not operators or (num_do > 1 and set(operators) <= action_buckets):
            complex_state, state = env.reset()
            aux_info = env.info['aux_info']
            continue
        bucket = operators[torch.randint(0, len(operators), (1,)).item()]
        do_variable_idx, do_type = divmod(bucket, 2)
        do_variable = do_variables[do_variable_idx]
        if do_type == 1 and do_variable != action_variable:
            goal = do_variable
            state, done, goal_achieved, _, complex_state = hrl.evaluate(env, hrl.k_level - 1, state, goal, complex_state)
            aux_info = env.env.aux_info
            if not goal_achieved:
                continue
        for _ in range(sample_step_size):
            action = torch.randint(0, env.action_space.n, (1,)).item()
            next_complex_state, next_state, _, done, info = env.step(action)
            next_aux_info = info['aux_info']
            # Write the action into copies: next_state becomes the next step's
            # state, and writing into it in place would overwrite the action
            # slot of a transition already stored as y.
            state = state.clone()
            state[-1] = action
            next_state = next_state.clone()
            next_state[-1] = action
            state_valid = not done and action > 3 and aux_info[0] > 0
            if state_valid and pending(bucket):
                no_op = (do_type == 1 and do_variable != action_variable
                         and bool((state == next_state).all()))
                if not no_op:
                    data[do_variable_idx][0].append(state)
                    data[do_variable_idx][1].append(next_state)
                    data[do_variable_idx][2].append(aux_info)
                    data_size[bucket] += 1
            stop = (done or next_state[do_variable] != state[do_variable]
                    or len(data[do_variable_idx][0]) >= expect_data_size)
            # Always advance, even after a dropped no-op, so aux_info and
            # complex_state stay in sync with the environment.
            state, complex_state, aux_info = next_state, next_complex_state, next_aux_info
            if stop:
                break
    for per_var in data:
        assert len(per_var[0]) == expect_data_size
    x_datas = torch.stack([torch.stack(d[0]).view(expect_data_size, env.variable_num) for d in data])
    y_datas = torch.stack([torch.stack(d[1]).view(expect_data_size, env.variable_num) for d in data])
    z_datas = torch.stack([torch.stack(d[2]).view(expect_data_size, env.aux_info_num) for d in data])
    x_datas = x_datas.view(num_do, expect_data_size, env.variable_num)
    y_datas = y_datas.view(num_do, expect_data_size, env.variable_num)
    z_datas = z_datas.view(num_do, expect_data_size, env.aux_info_num)
    return x_datas, y_datas, z_datas
def pre_train(env, scm, hrl, args):
    """Alternate SCM causal discovery and HRL goal training across MPI ranks.

    Each iteration runs three phases:

    * All ranks collect interventional data with ``do_sample_parallel``.
    * Rank 0 fits the SCM (``train_f`` then ``train_s``), extracts the DAG
      and trains HRL goals for newly reachable candidate variables.
    * The stop decision, HRL parameters and the list of intervenable
      variables are broadcast from rank 0, so every rank follows the same
      control flow.

    Args:
        env: environment providing ``variable_ranges`` and ``variable_names``.
        scm: SCM model on rank 0; only rank 0 uses it (``None`` elsewhere).
        hrl: HRL agent, present on every rank.
        args: parsed CLI namespace (``rank``, ``size``, ``I``, ``Ns``, ``B``,
            ``device``, ``model_path``, ``causal_threshold``,
            ``load_training_model``, ``scm_model_id``, ``hrl_model_id``,
            ``sample_step_size``).
    """
    print(args.rank, 'train scm & hrl!')
    var_num = len(env.variable_ranges)
    is_root = args.rank == 0
    # NOTE(review): discovery stops as soon as variable index 9 becomes a
    # candidate; presumably the task's final target variable — confirm.
    stop_variable = 9

    # The last variable is the action variable and is always intervenable.
    do_variables = [var_num - 1]
    dag = torch.zeros((var_num, var_num)).bool()
    if args.load_training_model:
        if is_root:
            dag = scm.load(args.model_path, str(args.scm_model_id))
        hrl.load(args.model_path, "HRL_{}".format(str(args.hrl_model_id)))
        do_variables.extend(i for i in range(var_num) if hrl.goal_valid[i])
    early_stop = 0
    for iter_idx in range(args.scm_model_id, args.I):
        # When resuming, the first iteration reuses the loaded SCM instead of resampling.
        if not args.load_training_model or iter_idx > args.scm_model_id:
            x_datas, y_datas, z_datas = do_sample_parallel(env, hrl, do_variables, args.Ns*args.B, args.rank, args.size, args.sample_step_size)
            if is_root:
                scm.train_f([x_datas.to(args.device), y_datas.to(args.device), z_datas.to(args.device)], do_variables, args)
                scm.train_s([x_datas.to(args.device), y_datas.to(args.device), z_datas.to(args.device)], do_variables, args)
        # update dag (rank 0 only)
        stop = False
        if is_root:
            dag, var_depths = scm.get_DAG(args.causal_threshold, do_variables, env.variable_names, dag)
            # Candidates: variables with at least one parent in the DAG, all of
            # whose parents are already intervenable.
            cand_variables = [
                v_idx for v_idx in range(var_num)
                if dag[v_idx, :].any() and v_idx not in do_variables
                and all(p in do_variables for p in range(var_num) if dag[v_idx, p])
            ]
            print('iter', iter_idx)
            print('var_depths', var_depths)
            print('cand vars', [(env.variable_names[int(var)], var_depths[int(var)]) for var in cand_variables])
            if not cand_variables:
                early_stop += 1
                stop = early_stop >= 2
            elif stop_variable in cand_variables:
                stop = True
            else:
                early_stop = 0
                print(0, 'cand vars', [(env.variable_names[int(var)], var_depths[int(var)]) for var in cand_variables])
                scm.save(args.model_path, str(iter_idx), dag)
                assert do_variables[0] == var_num - 1
                trained_variables = hrl.train_goals(env, dag, var_depths, cand_variables, args)
                hrl.save(args.model_path, "HRL_{}".format(str(iter_idx)))
                do_variables = do_variables + trained_variables
                print(0, 'do vars', [(env.variable_names[int(var)], var_depths[int(var)]) for var in do_variables])
                print('valid:', hrl.goal_valid)
        # All ranks must leave the loop together. A root-only break would leave
        # the other ranks blocked forever in the next bcast/gather.
        stop = comm.bcast(stop, root=0)
        if stop:
            break
        hrl_info = hrl.pack_params() if is_root else None
        hrl_info = comm.bcast(hrl_info, root=0)
        do_variables = comm.bcast(do_variables, root=0)
        if not is_root:
            hrl.unpack_params(hrl_info)

def _evaluate_task(env, hrl, k_level, state, complex_state, log_writer, episode_idx, num_episodes=5):
    """Run ``num_episodes`` evaluation episodes and log their summary scalars.

    Logs three scalars:

    * the minimum final goal distance;
    * the best episode time among episodes where ``env.env.last_game_over``
      is set (100 if none);
    * the fraction of episodes where it is set.

    Returns:
        The updated ``(state, complex_state, best_episode_time)``.
    """
    best_episode_time = 100
    min_distance = 9
    achieved = 0
    for _ in range(num_episodes):
        done = False
        episode_time = 0
        while not done:
            state, done, _, env_infos, complex_state = hrl.evaluate(env, k_level - 1, state, hrl.goal_dim, complex_state)
            sum_times, distance = env_infos
            episode_time += sum_times
        min_distance = min(min_distance, distance)
        if env.env.last_game_over:
            best_episode_time = min(best_episode_time, episode_time)
            achieved += 1
    log_writer.add_scalar('eval_goal_distance', min_distance, episode_idx)
    log_writer.add_scalar('eval_episode_time', best_episode_time, episode_idx)
    log_writer.add_scalar('eval_goal_achieve_ratio', achieved / num_episodes, episode_idx)
    return state, complex_state, best_episode_time


def train_task(env, hrl, args):
    """Train the task-level policy on top of the pre-trained goal hierarchy.

    Adds one level to ``hrl`` (horizon 5, discount ``args.task_gamma``) and
    runs ``args.task_train_steps`` top-level steps toward ``hrl.goal_dim``.

    * Every finished episode is logged.
    * Every ``args.task_eval_interval`` finished episodes, 5 evaluation
      episodes are run.
    * ``hrl.update`` runs every ``args.update_steps`` steps.
    * The agent is saved under ``<model_path>/rank_<rank>/`` once every
      ``args.save_steps`` finished episodes.
    """
    rank = args.rank
    task_log_directory = args.model_path+'/log/rank_'+str(rank)+'/'
    task_model_directory = args.model_path+'/rank_'+str(rank)+'/'
    os.makedirs(task_model_directory, exist_ok=True)
    task_log_writer = SummaryWriter(task_log_directory)
    print(rank, 'HRL task init!', '\nmodel path', task_model_directory, '\nlog path', task_log_directory)
    k_level = hrl.k_level+1
    hrl.set_hierarchy(k_level, hrl.H_list+[5], hrl.gamma_list+[args.task_gamma])

    if rank == 0:
        print('k_level', hrl.k_level)
        for i in range(hrl.goal_dim):
            if hrl.goal_valid[i]:
                print(env.variable_names[i//len(env.variable_operators)], hrl.goal_depth[i])

    episode_time = 0
    episode_idx = 0
    complex_state, state = env.reset()
    try:
        for i_steps in range(1, args.task_train_steps+1):
            state, done, _, env_infos, complex_state = hrl.run_HRL(env, hrl.k_level-1, state, hrl.goal_dim, False, complex_state, log_writer = task_log_writer)
            sum_times, distance = env_infos
            episode_time += sum_times
            if done:
                task_log_writer.add_scalar('train_episode_time', episode_time, episode_idx)
                task_log_writer.add_scalar('train_goal_distance', distance, episode_idx)
                episode_time = 0
                episode_idx += 1
                if episode_idx % args.task_eval_interval == 0:
                    state, complex_state, best_episode_time = _evaluate_task(
                        env, hrl, k_level, state, complex_state, task_log_writer, episode_idx)
                    if rank == 0:
                        print('i', i_steps, 'best_episode_time', best_episode_time)

            if i_steps % args.update_steps == 0:
                hrl.update(hrl.k_level-1, args.n_iter, args.batch_size, task_log_writer)

            # Save once per save_steps finished episodes. Testing episode_idx on
            # every step would re-save on each step of such an episode.
            if done and episode_idx % args.save_steps == 0:
                hrl.save(task_model_directory, "HRL_{}".format(str(episode_idx)))
                if rank == 0:
                    print(i_steps, 'saved agent in', task_model_directory)
    finally:
        task_log_writer.close()

if __name__ == '__main__':
    rank = comm.Get_rank()
    size = comm.Get_size()

    parser = argparse.ArgumentParser(description='cdhrl')
    # parameters for SCM
    #I, B, Fs, Qs, Ns, Cs, CausalThreshold
    parser.add_argument('--I', type=int, default=100)
    parser.add_argument('--B', type=int, default=256)
    parser.add_argument('--Fs', type=int, default=1000)
    parser.add_argument('--Qs', type=int, default=100)
    parser.add_argument('--Ns', type=int, default=20)
    parser.add_argument('--Cs', type=int, default=25)
    parser.add_argument('--l1_coef', type=float, default=0.05)
    parser.add_argument('--lmax_coef', type=float, default=0.05)
    parser.add_argument('--causal_threshold', type=float, default=0.8)
    parser.add_argument('--f_lr', type=float, default=5e-3)
    parser.add_argument('--s_lr', type=float, default=5e-2)
    parser.add_argument('--scm_model_id', type=int, default=0)
    parser.add_argument('--sample_step_size', type=int, default=50)

    # parameters for HRL
    parser.add_argument('--H', type=int, default=15)
    parser.add_argument('--eps', type=float, default=0.95)
    parser.add_argument('--lamda', type=float, default=0.2)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--task_gamma', type=float, default=0.95)
    parser.add_argument('--update_steps', type=int, default=3)
    parser.add_argument('--save_steps', type=int, default=1000)
    parser.add_argument('--n_iter', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--goal_train_steps', type=int, default=10000)
    parser.add_argument('--task_train_steps', type=int, default=1000000)
    parser.add_argument('--hrl_lr', type=float, default=0.0001)
    parser.add_argument('--train_threshold', type=float, default=0.5)
    parser.add_argument('--hrl_model_id', type=int, default=0)

    # parameters for logs
    parser.add_argument('--trained_model_path', type=str, default='pretrained_models/')
    # Every code path builds model and log directories from model_path, so a
    # missing value would crash with a TypeError on string concatenation.
    parser.add_argument('--model_path', type=str, required=True,
                        help='directory for saved models and tensorboard logs')
    parser.add_argument('--gpu_num', type=int, default=4)
    parser.add_argument('--load_training_model', type=str2bool, default=False)
    parser.add_argument('--train_task', type=str2bool, default=False)
    parser.add_argument('--task_eval_interval', type=int, default=10)
    parser.add_argument('--norm_scale', type=int, default=10)

    args = parser.parse_args()

    # Set the device once, before printing, so the logged args show the
    # device actually used (previously a dead first assignment was printed).
    args.device = torch.device('cuda:'+str((rank+1) % args.gpu_num))
    args.rank = rank
    args.size = size
    if rank == 0:
        log_writer = SummaryWriter(args.model_path+'/log/')
        args.log_writer = log_writer
        print(args)

    env = McEnv(seed = 42+args.rank**2)
    print(rank, 'init minecraft!')
    action_num = env.action_space.n
    variable_num = len(env.variable_ranges)
    variable_ranges = env.variable_ranges
    hrl = HRL(k_level = 1,
              H_list = [args.H],
              state_dim = env.observation_space.shape[0],
              action_dim = action_num,
              goal_dim = variable_num*len(env.variable_operators),
              lr = args.hrl_lr,
              eps = args.eps,
              goal_ranges = variable_ranges,
              device = args.device,
              lamda = args.lamda,
              gamma_list = [args.gamma],
              task_gamma = args.task_gamma,
              operators = env.variable_operators,
              norm_scale = args.norm_scale)
    print(rank, 'init hrl!')
    # Only rank 0 owns the SCM; other ranks pass None to pre_train.
    scm = None
    if rank == 0:
        scm = SCM(variable_ranges, args.s_lr, args.f_lr, args.device, aux_range_list=env.aux_info_ranges, variable_names = env.variable_names)
        print(rank, 'init scm model!')

    if args.train_task:
        hrl.load(args.trained_model_path, "HRL_minecraftpretrain6")
        train_task(env, hrl, args)
    else:
        pre_train(env, scm, hrl, args)
